
if __name__ == '__main__':
    # Compute surplus components for every time unit ('month', 'fortnight')
    # and every (contract length, rig type) pair by calling
    # surplus.init_fast_surplus, then write the results under ./models/surplus/.
    # All paths are relative to the current working directory, so this is
    # expected to be run from the project root.
    import itertools
    import os
    import sys

    import numpy as np
    import pandas as pd

    # The project root must be on sys.path before any `src.` import below runs.
    sys.path.append('./')

    try:
        from src.interpolation.splines import UCGrid
    except ImportError:
        # Fallback when executed with a package context (e.g. `python -m ...`).
        from ..interpolation.splines import UCGrid

    from src.models_new import surplus
    from src.models_new import construct_data

    #%% SET THE PARAMETERS --------------------------------------------------------------
    mri_grid = 5
    g_grid = 10
    n_grid = 5
    seeds = list(range(200))
    max_sim_length = 120
    beta_by_time = {'month': 0.99, 'fortnight': 0.995}
    # Factor applied to each tau before it is passed to init_fast_surplus;
    # presumably converts a contract length in months into model periods.
    period_multiplier_by_time = {'month': 1, 'fortnight': 2}
    taus = [2, 3, 4]  # contract lengths (per the progress message below)
    specs = ['low', 'mid', 'high']  # rig types
    options = {
        'threads_per_worker': 1,
        'n_workers': 7
    }

    # np.save / np.savetxt do not create missing parent directories.
    os.makedirs('./models/surplus', exist_ok=True)

    for t in ['month', 'fortnight']:
        #%% READ IN THE INPUTS FROM EARLIER STAGE -------------------------------------------
        r = pd.read_csv(f'./models/first_stage/r_{t}.csv', index_col=[0])
        const = pd.read_csv(f'./models/first_stage/const_{t}.csv', index_col=[0])
        sigma = pd.read_csv(f'./models/first_stage/sigma_{t}.csv', index_col=[0])
        df_extensions_with_mri = pd.read_csv(
            f'./models/first_stage/first_stage_extensions_with_mri_{t}.csv', index_col=[0])

        # Per rig type: the value-search grid (rows 0-3 of the saved array are
        # passed as the four UCGrid arguments), the values on that grid, and
        # an extension-probability predictor built from the first-stage coefficients.
        search_grid_by_spec = dict()
        search_value_by_spec = dict()
        prob_extend_by_spec = dict()
        for spec in specs:
            grid_params = np.load(f'./models/value_search/search_grid_{spec}_{t}.npy')
            search_grid_by_spec[spec] = UCGrid(
                tuple(grid_params[0, :]),
                tuple(grid_params[1, :]),
                tuple(grid_params[2, :]),
                tuple(grid_params[3, :])
            )
            search_value_by_spec[spec] = np.load(f'./models/value_search/search_value_{spec}_{t}.npy')
            prob_extend_by_spec[spec] = construct_data.build_prob_extension_predictor_with_mri(
                coefs=df_extensions_with_mri.loc[spec].values)

        #%% COMPUTE THE SURPLUS COMPONENTS AND SAVE -------------------------------------
        # Dict lookup (like beta_by_time) instead of an if/elif chain, so an
        # unexpected time unit fails loudly rather than leaving this unbound.
        period_multiplier = period_multiplier_by_time[t]

        match_values_by_tau_spec = dict()
        for tau, spec in itertools.product(taus, specs):
            print(f'Getting surplus for: contract length: {tau}; rig type: {spec}')
            (
                grid_surplus,
                _nodes_grid,  # not used by this script
                nodes_list,
                match_values_by_tau_spec[(tau, spec)]
            ) = surplus.init_fast_surplus(
                mri_grid=mri_grid,
                g_grid=g_grid,
                n_grid=n_grid,
                tau=tau * period_multiplier,
                spec=spec,
                grid=search_grid_by_spec[spec],
                values=search_value_by_spec[spec],
                prob_extend=prob_extend_by_spec[spec],
                const=const.values.T[0],  # first column of the const table
                r=r.values,
                sigma=sigma.values[0][0],  # scalar taken from the first cell
                seeds=seeds,
                max_sim_length=max_sim_length,
                beta=beta_by_time[t],
                options=options
            )

            np.savetxt(
                f'./models/surplus/surplus_components_{tau}_{spec}_{t}.txt',
                match_values_by_tau_spec[(tau, spec)]
            )
            np.save(
                f'./models/surplus/surplus_grid_{tau}_{spec}_{t}.npy',
                grid_surplus
            )

        # Writes the nodes returned by the last (tau, spec) call; assumes they
        # are the same for every call (they depend on the shared grid sizes)
        # — TODO confirm.
        np.savetxt(
            f'./models/surplus/nodes_{t}.txt',
            np.array(nodes_list)
        )
